In [1]:
import os
# GPU 0 is busy on this machine, so expose only physical GPU 2 to this process.
# This is done in its own cell, before torch is imported in the next cell.
os.environ.update({'CUDA_VISIBLE_DEVICES': '2'})
In [2]:
import torch
import jupytertracerviz
from torchvision.models import resnet18
# ResNet-18 built with default arguments (no weights passed), then moved to the GPU.
model = resnet18()
model = model.cuda()

# Ten independent random batches of shape (5, 3, 224, 224), allocated directly
# on the CUDA device; index 0 is used for warm-up, 1..3 for the profiled runs.
inputs = [torch.randn(5, 3, 224, 224, device='cuda') for _ in range(10)]

# Compiled wrapper around the same module; this is what fwd_bwd calls.
model_c = torch.compile(model)
def fwd_bwd(inp):
    """Run one forward and one backward pass of the compiled model.

    Calls the module-level ``model_c`` on ``inp``, reduces the output to a
    scalar with ``sum()`` and backpropagates from it. Gradients are not
    zeroed here, so they accumulate across repeated calls.

    Args:
        inp: Input batch passed straight to ``model_c``.

    Returns:
        None. The only effect is the gradients populated by ``backward()``.
    """
    out = model_c(inp)
    # Sum to a scalar so backward() can be called without an explicit gradient.
    out.sum().backward()
# Warm-up: run one iteration outside the profiler so that one-time start-up
# cost (e.g. torch.compile's first-call compilation) stays out of the trace.
fwd_bwd(inputs[0])

# Profile three forward/backward iterations on inputs[1], inputs[2], inputs[3].
with torch.profiler.profile() as prof:
    for i in range(1, 4):
        fwd_bwd(inputs[i])
        # Signal the profiler that one iteration has finished.
        prof.step()

# Export the collected events in Chrome trace format, then hand the file to
# jupytertracerviz (presumably renders the trace viewer in the cell output).
prof.export_chrome_trace("trace.json")
jupytertracerviz.visualize("trace.json", height="800")
Dumping trace data, total entries: 9655
In [ ]: